import random
import numpy as np


# Small additive constant used inside the square roots of ORLC.update_qv to
# keep their arguments away from zero.
STABLE = 1e-10


class ORLC:
    """Tabular episodic agent that keeps upper and lower value bounds.

    For every step ``h`` of an ``H``-step episode it stores optimistic
    (``*_up``) and pessimistic (``*_low``) Q- and V-tables. It also keeps
    empirical transition counts and running-mean rewards, which are shared
    across steps. Actions are greedy w.r.t. the upper Q-table, and
    ``get_certificates`` reports the ``[lower, upper]`` value interval at
    step 0.

    The ``# row N`` comments refer to lines of the algorithm listing this
    class implements (presumably the ORLC pseudocode -- TODO confirm).
    """

    def __init__(self,
                 n_state: int,
                 n_action: int,
                 n_episode: int,
                 n_step: int,
                 rho: float,
                 iota: float,
                 const: float):
        """Allocate and initialise all tables.

        Args:
            n_state: number of states ``S``.
            n_action: number of actions ``A``.
            n_episode: number of episodes ``K`` (stored only).
            n_step: episode length ``H``.
            rho: stored as ``self.rho``; not read anywhere in this class.
            iota: additive log-confidence term inside the ``phi`` bonus.
            const: scale of the second-order bonus; stored multiplied by 10.
        """
        self.S = n_state
        self.A = n_action
        self.K = n_episode
        self.H = n_step
        # NOTE(review): delta is set to n_step and never read in this class;
        # possibly a placeholder for a confidence parameter -- confirm intent.
        self.delta = n_step
        self.iota = iota
        self.const = const * 10

        # initialize tables (row 1)
        self.V_table_up = np.zeros([self.H + 1, self.S])
        self.V_table_low = np.zeros([self.H + 1, self.S])
        self.Q_table_up = np.zeros([self.H + 1, self.S, self.A])
        self.Q_table_low = np.zeros([self.H + 1, self.S, self.A])
        self.N_table = np.zeros([self.S, self.A, self.S])  # visit counts (s, a, s')
        self.R_table = np.zeros([self.S, self.A])     # reward (running mean)
        # NOTE(review): R_var_table is allocated but never updated or read here.
        self.R_var_table = np.zeros([self.S, self.A])

        # Optimistic start: upper bounds equal the remaining horizon H - h for
        # h < H; the terminal step H stays at 0.
        remaining = self.H - np.arange(self.H)
        self.V_table_up[:self.H] = remaining[:, np.newaxis]
        self.Q_table_up[:self.H] = remaining[:, np.newaxis, np.newaxis]

        # Labelled "rho-greedy", but take_action below is purely greedy.
        self.rho = rho

    def take_action(self,
                    state: int,
                    h: int,
                    is_train: bool) -> int:
        """Return the action maximising the upper Q-bound at ``(h, state)``.

        Ties go to the lowest action index (``np.argmax`` semantics).
        ``is_train`` is accepted for interface compatibility but unused; no
        rho-greedy exploration is applied here.
        """
        return int(np.argmax(self.Q_table_up[h][state]))

    def get_certificates(self, state):
        """Return ``(lower, upper, upper - lower)`` value bounds at step 0."""
        low = self.V_table_low[0][state]
        high = self.V_table_up[0][state]
        return (low, high, high - low)

    def update(self, s0, a0, r, s1, h) -> None:
        """Record one observed transition ``(s0, a0) -> s1`` with reward ``r``.

        Increments the visit count and updates the running-mean reward of
        ``(s0, a0)``. ``h`` is accepted but unused: counts and rewards are
        shared across all steps.
        """
        self.N_table[s0][a0][s1] += 1  # row 8
        n_visits = self.N_table[s0][a0].sum()  # includes this visit, so >= 1
        self.R_table[s0][a0] += (r - self.R_table[s0][a0]) / n_visits  # row 9

    def update_qv(self):
        """Recompute the upper/lower Q- and V-tables by backward induction.

        It uses the empirical model ``P[s, a, s'] = N[s, a, s'] / N[s, a]``
        and a per-(s, a) bonus ``theta``. ``theta`` combines a count-based
        term ``phi``, the empirical spread of the next-step upper value, a
        second-order ``phi**2`` term and the next-step upper-lower gap.
        Upper Q-values are capped at ``H - h``; lower ones are floored at 0.
        Unvisited (s, a) pairs get an all-zero transition row and
        ``phi == 1``.
        """
        const = 45 * self.S * (self.H ** 2) / self.const

        # None of these depend on h (N_table is not modified below), so they
        # are computed once. Unvisited pairs divide by zero on purpose:
        # 0/0 -> nan -> 0 in P_h, and x/0 -> inf -> capped at 1 in phi.
        # errstate silences the resulting RuntimeWarnings without changing
        # any values.
        with np.errstate(divide="ignore", invalid="ignore"):
            Nhsa = self.N_table.sum(axis=-1)  # (s * a)
            P_h = np.nan_to_num(self.N_table / Nhsa[:, :, np.newaxis])  # (s * a * s)
            phi = np.minimum(np.sqrt((0.52 * (1.4 * np.log(np.log(np.maximum(3, Nhsa))) + self.iota) + STABLE) / Nhsa) / 100, 1)
        second_order = const * (phi ** 2)  # row 18, 2nd term
        states = np.arange(self.S)

        for h in range(self.H - 1, -1, -1):  # row 16
            v_up_next = self.V_table_up[h + 1]
            v_low_next = self.V_table_low[h + 1]
            ev_up = P_h @ v_up_next  # expected next upper value, (s * a)

            # E[V^2] - E[V]^2 can cancel to a small negative number when H is
            # large; clamp so sqrt never produces NaN (no-op otherwise).
            spread = (P_h @ (v_up_next ** 2) + STABLE - ev_up ** 2) * 12 + STABLE
            theta = phi * (1 + np.sqrt(np.maximum(spread, 0.0)))
            theta += second_order
            theta += (1 / self.H) * P_h @ (v_up_next - v_low_next)  # row 18, 3rd term

            self.Q_table_up[h] = np.clip(self.R_table + ev_up + theta, None, self.H - h)  # (s * a) row 19
            self.Q_table_low[h] = np.clip(self.R_table + P_h @ v_low_next - theta, 0, None)  # (s * a) row 20

            # Both V bounds are read at the action that is greedy w.r.t. the
            # upper Q-table (first index on ties).
            greedy = np.argmax(self.Q_table_up[h], axis=1)
            self.V_table_up[h] = self.Q_table_up[h][states, greedy]
            self.V_table_low[h] = self.Q_table_low[h][states, greedy]








